{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "DuelingDQN.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "RaH44VKfCSFO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import random\n",
        "import numpy as np\n",
        "from collections import deque\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.autograd as autograd"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YwEWUBC8CofJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def mini_batch_train(env, agent, max_episodes, max_steps, batch_size):\n",
        "    episode_rewards = []\n",
        "\n",
        "    for episode in range(max_episodes):\n",
        "        state = env.reset()\n",
        "        episode_reward = 0\n",
        "\n",
        "        for step in range(max_steps):\n",
        "            action = agent.get_action(state)\n",
        "            next_state, reward, done, _ = env.step(action)\n",
        "            agent.replay_buffer.push(state, action, reward, next_state, done)\n",
        "            episode_reward += reward\n",
        "\n",
        "            if len(agent.replay_buffer) > batch_size:\n",
        "                agent.update(batch_size)   \n",
        "\n",
        "            if done or step == max_steps-1:\n",
        "                episode_rewards.append(episode_reward)\n",
        "                print(\"Episode \" + str(episode) + \": \" + str(episode_reward))\n",
        "                break\n",
        "\n",
        "            state = next_state\n",
        "\n",
        "    return episode_rewards"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Qm5u4V1tCIPO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "\n",
        "class BasicBuffer:\n",
        "\n",
        "    def __init__(self, max_size):\n",
        "        self.max_size = max_size\n",
        "        self.buffer = deque(maxlen=max_size)\n",
        "\n",
        "    def push(self, state, action, reward, next_state, done):\n",
        "        experience = (state, action, np.array([reward]), next_state, done)\n",
        "        self.buffer.append(experience)\n",
        "\n",
        "    def sample(self, batch_size):\n",
        "        state_batch = []\n",
        "        action_batch = []\n",
        "        reward_batch = []\n",
        "        next_state_batch = []\n",
        "        done_batch = []\n",
        "\n",
        "        batch = random.sample(self.buffer, batch_size)\n",
        "\n",
        "        for experience in batch:\n",
        "            state, action, reward, next_state, done = experience\n",
        "            state_batch.append(state)\n",
        "            action_batch.append(action)\n",
        "            reward_batch.append(reward)\n",
        "            next_state_batch.append(next_state)\n",
        "            done_batch.append(done)\n",
        "\n",
        "        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)\n",
        "\n",
        "    def sample_sequence(self, batch_size):\n",
        "        state_batch = []\n",
        "        action_batch = []\n",
        "        reward_batch = []\n",
        "        next_state_batch = []\n",
        "        done_batch = []\n",
        "\n",
        "        min_start = len(self.buffer) - batch_size\n",
        "        start = np.random.randint(0, min_start)\n",
        "\n",
        "        for sample in range(start, start + batch_size):\n",
        "            state, action, reward, next_state, done = self.buffer[start]\n",
        "            state, action, reward, next_state, done = experience\n",
        "            state_batch.append(state)\n",
        "            action_batch.append(action)\n",
        "            reward_batch.append(reward)\n",
        "            next_state_batch.append(next_state)\n",
        "            done_batch.append(done)\n",
        "\n",
        "        return (state_batch, action_batch, reward_batch, next_state_batch, done_batch)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.buffer)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MrMnZS4YCru1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class ConvDuelingDQN(nn.Module):\n",
        "\n",
        "    def __init__(self, input_dim, output_dim):\n",
        "        super(ConvDuelingDQN, self).__init__()\n",
        "        self.input_dim = input_dim\n",
        "        self.output_dim = output_dim\n",
        "        self.fc_input_dim = self.feature_size()\n",
        "        \n",
        "        self.conv = nn.Sequential(\n",
        "            nn.Conv2d(input_dim[0], 32, kernel_size=8, stride=4),\n",
        "            nn.ReLU(),\n",
        "            nn.Conv2d(32, 64, kernel_size=4, stride=2),\n",
        "            nn.ReLU(),\n",
        "            nn.Conv2d(64, 64, kernel_size=3, stride=1),\n",
        "            nn.ReLU()\n",
        "        )\n",
        "\n",
        "        self.value_stream = nn.Sequential(\n",
        "            nn.Linear(self.fc_input_dim, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, 1)\n",
        "        )\n",
        "\n",
        "        self.advantage_stream = nn.Sequential(\n",
        "            nn.Linear(self.fc_input_dim, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, self.output_dim)\n",
        "        )\n",
        "\n",
        "    def forward(self, state):\n",
        "        features = self.conv(state)\n",
        "        features = features.view(features.size(0), -1)\n",
        "        values = self.value_stream(features)\n",
        "        advantages = self.advantage_stream(features)\n",
        "        qvals = values + (advantages - advantages.mean())\n",
        "        \n",
        "        return qvals\n",
        "\n",
        "    def feature_size(self):\n",
        "        return self.conv(autograd.Variable(torch.zeros(1, *self.input_dim))).view(1, -1).size(1)\n",
        "\n",
        "\n",
        "class DuelingDQN(nn.Module):\n",
        "\n",
        "    def __init__(self, input_dim, output_dim):\n",
        "        super(DuelingDQN, self).__init__()\n",
        "        self.input_dim = input_dim\n",
        "        self.output_dim = output_dim\n",
        "        \n",
        "        self.feauture_layer = nn.Sequential(\n",
        "            nn.Linear(self.input_dim[0], 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, 128),\n",
        "            nn.ReLU()\n",
        "        )\n",
        "\n",
        "        self.value_stream = nn.Sequential(\n",
        "            nn.Linear(128, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, 1)\n",
        "        )\n",
        "\n",
        "        self.advantage_stream = nn.Sequential(\n",
        "            nn.Linear(128, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(128, self.output_dim)\n",
        "        )\n",
        "\n",
        "    def forward(self, state):\n",
        "        features = self.feauture_layer(state)\n",
        "        values = self.value_stream(features)\n",
        "        advantages = self.advantage_stream(features)\n",
        "        qvals = values + (advantages - advantages.mean())\n",
        "        \n",
        "        return qvals"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kN4c4N8MCjYt",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class DuelingAgent:\n",
        "\n",
        "    def __init__(self, env, use_conv=True, learning_rate=3e-4, gamma=0.99, buffer_size=10000):\n",
        "        self.env = env\n",
        "        self.learning_rate = learning_rate\n",
        "        self.gamma = gamma\n",
        "        self.replay_buffer = BasicBuffer(max_size=buffer_size)\n",
        "   \n",
        "        self.device = \"cpu\"\n",
        "        if torch.cuda.is_available():\n",
        "            self.device = \"cuda\"\n",
        "\n",
        "        self.use_conv = use_conv        \n",
        "        if self.use_conv:\n",
        "            self.model = ConvDuelingDQN(env.observation_space.shape, env.action_space.n).to(self.device)\n",
        "        else:\n",
        "            self.model = DuelingDQN(env.observation_space.shape, env.action_space.n).to(self.device)\n",
        "        \n",
        "      \n",
        "        self.optimizer = torch.optim.Adam(self.model.parameters())\n",
        "        self.MSE_loss = nn.MSELoss()\n",
        "\n",
        "    def get_action(self, state, eps=0.20):\n",
        "        state = torch.FloatTensor(state).float().unsqueeze(0).to(self.device)\n",
        "        qvals = self.model.forward(state)\n",
        "        action = np.argmax(qvals.cpu().detach().numpy())\n",
        "        \n",
        "        if(np.random.randn() > eps):\n",
        "            return self.env.action_space.sample()\n",
        "        return action\n",
        "\n",
        "    def compute_loss(self, batch):\n",
        "        states, actions, rewards, next_states, dones = batch\n",
        "        states = torch.FloatTensor(states).to(self.device)\n",
        "        actions = torch.LongTensor(actions).to(self.device)\n",
        "        rewards = torch.FloatTensor(rewards).to(self.device)\n",
        "        next_states = torch.FloatTensor(next_states).to(self.device)\n",
        "        dones = torch.FloatTensor(dones).to(self.device)\n",
        "\n",
        "        curr_Q = self.model.forward(states).gather(1, actions.unsqueeze(1))\n",
        "        curr_Q = curr_Q.squeeze(1)\n",
        "        next_Q = self.model.forward(next_states)\n",
        "        max_next_Q = torch.max(next_Q, 1)[0]\n",
        "        expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q\n",
        "\n",
        "        loss = self.MSE_loss(curr_Q, expected_Q)\n",
        "        \n",
        "        return loss\n",
        "\n",
        "    def update(self, batch_size):\n",
        "        batch = self.replay_buffer.sample(batch_size)\n",
        "        loss = self.compute_loss(batch)\n",
        "\n",
        "        self.optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        self.optimizer.step()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0YkJYcQxCxur",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 754
        },
        "outputId": "3baf12eb-8b5b-48b0-88d0-4d4c68fb413c"
      },
      "source": [
        "import gym\n",
        "\n",
        "\n",
        "env_id = \"CartPole-v0\"\n",
        "MAX_EPISODES = 1000\n",
        "MAX_STEPS = 500\n",
        "BATCH_SIZE = 32\n",
        "\n",
        "env = gym.make(env_id)\n",
        "agent = DuelingAgent(env, use_conv=False)\n",
        "episode_rewards = mini_batch_train(env, agent, MAX_EPISODES, MAX_STEPS, BATCH_SIZE)"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Episode 0: 25.0\n",
            "Episode 1: 15.0\n",
            "Episode 2: 20.0\n",
            "Episode 3: 12.0\n",
            "Episode 4: 36.0\n",
            "Episode 5: 47.0\n",
            "Episode 6: 16.0\n",
            "Episode 7: 11.0\n",
            "Episode 8: 71.0\n",
            "Episode 9: 46.0\n",
            "Episode 10: 28.0\n",
            "Episode 11: 47.0\n",
            "Episode 12: 152.0\n",
            "Episode 13: 14.0\n",
            "Episode 14: 18.0\n",
            "Episode 15: 24.0\n",
            "Episode 16: 13.0\n",
            "Episode 17: 22.0\n",
            "Episode 18: 18.0\n",
            "Episode 19: 21.0\n",
            "Episode 20: 34.0\n",
            "Episode 21: 21.0\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-8-ee8478f09546>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0menv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgym\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_id\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0magent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDuelingAgent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_conv\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0mepisode_rewards\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmini_batch_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMAX_EPISODES\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMAX_STEPS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mBATCH_SIZE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m<ipython-input-2-1d2ff31503ef>\u001b[0m in \u001b[0;36mmini_batch_train\u001b[0;34m(env, agent, max_episodes, max_steps, batch_size)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m             \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m             \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m             \u001b[0magent\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplay_buffer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpush\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-7-bcbd16c8dafc>\u001b[0m in \u001b[0;36mget_action\u001b[0;34m(self, state, eps)\u001b[0m\n\u001b[1;32m     23\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget_action\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m         \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m         \u001b[0mqvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m         \u001b[0maction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqvals\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-4-5b3aa6d01c68>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, state)\u001b[0m\n\u001b[1;32m     68\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     69\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m         \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeauture_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     71\u001b[0m         \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue_stream\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     72\u001b[0m         \u001b[0madvantages\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madvantage_stream\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    492\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    494\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     90\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     93\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    492\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    494\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     90\u001b[0m     \u001b[0;34m@\u001b[0m\u001b[0mweak_script_method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     93\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     94\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m   1404\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mbias\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1405\u001b[0m         \u001b[0;31m# fused op is marginally faster\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1406\u001b[0;31m         \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddmm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1407\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1408\u001b[0m         \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pdjQyL-sGQRr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}